"""
2023-03-08
Run inference on 1975-2015 PatentsView granted-patent data using the pretrained PatentSBERTa model.
"""

import os
import json
import argparse
import torch
import logging
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Optional, Union
import pandas as pd
import numpy as np

# Configure the root logger once, at import time. With no filename/stream
# argument, basicConfig attaches a stream handler, so messages go to the
# console in a timestamped format.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Root directory used by the command-line entry point: --eval_path and
# --output_dir are joined onto it, and the "logs" directory lives under it.
# (A `global` statement is meaningless at module scope, so none is used.)
BASE_DIR = "/Volumes/Zihao_SSD2/"

# Ask PyTorch to release any unused cached CUDA memory before the model is
# loaded later on.
torch.cuda.empty_cache()


def evaluate_specter_pl(
    model_name,
    eval_path,
    output_dir,
    batch_size=8,
    chunk_size=20000,
    embed_separate=False,
):
    """
    Embed patent titles and abstracts with PatentSBERTa and write JSONL files.

    The CSV at ``eval_path`` is read without a header row, ``chunk_size`` rows
    at a time, until the file is exhausted. Every chunk must have exactly four
    columns, which are named ``patent_id``, ``patent_year``, ``patent_title``
    and ``patent_abstract`` (in that order). For each row the text
    ``title + sep_token + abstract`` is tokenized (truncated to 512 tokens),
    passed through the model, and the last-layer hidden state of the first
    token is kept as the embedding.

    Chunk number ``i`` (starting at 0) is written to
    ``<output_dir>/patent_embeddings_<model_name>_<i>.jsonl`` with one
    ``{"patent_id": ..., "embedding": [...]}`` object per line.

    Args:
        model_name: name of the pretrained model; only "patentsberta" is
            supported
        eval_path: path to the raw patent CSV file
        output_dir: output directory for the embedding files (created if
            it does not exist)
        batch_size: batch size for inference
        chunk_size: number of CSV rows read, embedded and written per
            output file
        embed_separate: accepted for interface compatibility; this function
            does not use it (title and abstract are always embedded together)

    Returns:
        None

    Raises:
        ValueError: if ``model_name`` is not supported, if a chunk does not
            have four columns, or if the number of embeddings produced for a
            chunk differs from the number of rows in it.
    """
    # Validate before the expensive import / model download.
    if model_name != "patentsberta":
        raise ValueError("model name is not correct.")

    from transformers import AutoTokenizer, AutoModel

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("AI-Growth-Lab/PatentSBERTa")
    model = AutoModel.from_pretrained("AI-Growth-Lab/PatentSBERTa")
    model = model.to(device)
    # Inference only: make sure dropout etc. are disabled.
    model.eval()

    if not os.path.exists(output_dir):
        logger.info(f"Create: {output_dir}")
        os.makedirs(output_dir, exist_ok=True)

    column_names = [
        "patent_id",
        "patent_year",
        "patent_title",
        "patent_abstract",
    ]

    logger.info(f"Reading {eval_path} in chunks of {chunk_size} rows")

    # Stream the file once with pandas' chunked reader. This handles any file
    # length (no hard-coded number of chunks) and avoids re-opening the file
    # and skipping all previously processed rows for every chunk.
    with pd.read_csv(eval_path, header=None, chunksize=chunk_size) as reader:
        for ind, ds_metadata in enumerate(reader):
            ds_metadata.columns = column_names
            first_row = chunk_size * ind
            logger.info(
                f"processing rows: {first_row} to {first_row + len(ds_metadata)}"
            )
            logger.info(f"embedding {ind}th chunk.")

            # Missing titles/abstracts are parsed as NaN (a float); replace
            # them with empty strings so the text concatenation cannot fail.
            titles = ds_metadata["patent_title"].fillna("")
            abstracts = ds_metadata["patent_abstract"].fillna("")
            sep_token = tokenizer.sep_token
            ds_docs = [
                f"{title}{sep_token}{abstract}"
                for title, abstract in zip(titles, abstracts)
            ]

            logger.info("Tokenize text data")
            ds_inputs = tokenizer(
                ds_docs,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,
            )

            logger.info("Predict text data")
            dl = DataLoader(
                DictOfListsDataset(ds_inputs), batch_size=batch_size, shuffle=False
            )
            embeds = []

            # No gradients are needed for inference; disabling autograd keeps
            # memory usage flat across batches.
            with torch.no_grad():
                for batch_inputs in tqdm(dl, total=len(dl)):
                    batch_inputs = {k: v.to(device) for k, v in batch_inputs.items()}
                    model_out = model(**batch_inputs)
                    # Hidden state of the first token of every sequence.
                    embeds.extend(model_out[0][:, 0, :].cpu().tolist())

            if len(embeds) != len(ds_metadata):
                raise ValueError(
                    f"Invalid embeddings count: {len(embeds)} vs {len(ds_metadata)}"
                )

            out_fp = os.path.join(
                output_dir, f"patent_embeddings_{model_name}_{ind}.jsonl"
            )
            logger.info(f"Write {model_name} {ind} jsonl embeddings to {out_fp}")

            # Pair ids and embeddings by position: the chunked reader keeps a
            # running row index, so index labels are not valid list offsets.
            patent_ids = ds_metadata["patent_id"].tolist()
            with open(out_fp, "w", encoding="utf-8") as f:
                for patent_id, embedding in zip(patent_ids, embeds):
                    f.write(
                        json.dumps({"patent_id": patent_id, "embedding": embedding})
                        + "\n"
                    )

    logger.info("done.")


class DictOfListsDataset(Dataset):
    """Expose a mapping of indexable columns as a row-oriented dataset.

    ``samples`` maps field names to indexable containers. Item ``idx`` of the
    dataset is a dict holding element ``idx`` of every container, keyed by
    the same field names.
    """

    def __init__(self, samples: Dict[str, List]):
        # The mapping is stored as given; nothing is copied.
        self.samples = samples

    def __len__(self):
        # Only the first field is measured; presumably all fields have the
        # same length -- this is not checked here.
        field_names = list(self.samples.keys())
        first_field = field_names[0]
        return len(self.samples[first_field])

    def __getitem__(self, idx):
        row = {}
        for field, column in self.samples.items():
            row[field] = column[idx]
        return row


def main(
    model_name,
    eval_path: str,
    output_dir: str,
    embed_separate: bool,
    batch_size: int,
    chunk_size: int,
):
    r"""
    Forward the given options, by keyword, to ``evaluate_specter_pl``.

    Example use:

    cd /Volumes/Zihao_SSD2/PatentsView

    python3 code/pred_gender_patent.py \
        --model_name patentsberta \
        --eval_path patentsberta/patent_raw.csv \
        --output_dir patentsberta/patent_embedding_jsonl \
        --embed_separate False \
        --batch_size 8 \
        --chunk_size 20000
    """
    # Collect the options first so the hand-off is a single keyword call.
    options = {
        "model_name": model_name,
        "eval_path": eval_path,
        "output_dir": output_dir,
        "embed_separate": embed_separate,
        "batch_size": batch_size,
        "chunk_size": chunk_size,
    }
    evaluate_specter_pl(**options)


if __name__ == "__main__":

    def _str_to_bool(value):
        """Convert a command-line token such as "True" or "false" to a bool.

        Without a converter argparse keeps the raw string, and any non-empty
        string -- including "False" -- is truthy.
        """
        normalized = str(value).strip().lower()
        if normalized in {"true", "t", "yes", "y", "1"}:
            return True
        if normalized in {"false", "f", "no", "n", "0"}:
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    ap = argparse.ArgumentParser()
    # The three options below are required: leaving any of them out could
    # only fail later with a less helpful error.
    ap.add_argument(
        "--model_name",
        choices=["patentsberta"],
        required=True,
        help="choose pretrained models",
    )
    ap.add_argument(
        "--eval_path",
        required=True,
        help="path to the patent CSV file (joined onto BASE_DIR)",
    )
    ap.add_argument(
        "--output_dir",
        required=True,
        help="output directory to files (joined onto BASE_DIR)",
    )
    ap.add_argument(
        "--embed_separate",
        type=_str_to_bool,
        default=False,
        help="separately embed title and abstract",
    )
    ap.add_argument("--batch_size", type=int, default=8)
    ap.add_argument("--chunk_size", type=int, default=20000)
    args = ap.parse_args()

    # logging.basicConfig() does nothing once the root logger has a handler,
    # and the module-level call at the top of this file has already installed
    # one. Attach a file handler explicitly so the log file is really written
    # (console logging stays in place).
    log_dir = os.path.join(BASE_DIR, "logs")
    log_path = os.path.join(log_dir, f"embed_patent_{args.model_name}.log")
    try:
        os.makedirs(log_dir, exist_ok=True)
        file_handler = logging.FileHandler(log_path)
    except OSError as err:
        # File logging is best effort: keep running with console logging only.
        logger.warning(f"Could not open log file {log_path}: {err}")
    else:
        file_handler.setFormatter(
            logging.Formatter(
                fmt="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S",
            )
        )
        logging.getLogger().addHandler(file_handler)

    # os.path.join keeps an absolute second argument as-is, so absolute
    # paths on the command line bypass BASE_DIR.
    eval_file = os.path.join(BASE_DIR, args.eval_path)
    output_dir = os.path.join(BASE_DIR, args.output_dir)

    main(
        args.model_name,
        eval_file,
        output_dir,
        args.embed_separate,
        args.batch_size,
        args.chunk_size,
    )
